#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""Utilities for maximizing acquisition functions."""

from __future__ import annotations

from collections.abc import Mapping
from warnings import warn

import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.exceptions.errors import BotorchError, BotorchTensorDimensionError
from botorch.exceptions.warnings import BotorchWarning
from botorch.models.gpytorch import ModelListGPyTorchModel
from torch import Tensor


def columnwise_clamp(
    X: Tensor,
    lower: float | Tensor | None = None,
    upper: float | Tensor | None = None,
    raise_on_violation: bool = False,
) -> Tensor:
    r"""Clamp values of a Tensor in column-wise fashion (with support for t-batches).

    This function is useful in conjunction with optimizers from the torch.optim
    package, which don't natively handle constraints. If you apply this after a
    gradient step, you can be fancy and call it "projected gradient descent".
    This function is also useful for post-processing candidates generated by the
    scipy optimizer, which may satisfy the imposed bounds only up to numerical
    accuracy.

    Args:
        X: The `b x n x d` input tensor. If 2-dimensional, `b` is assumed to be 1.
        lower: The column-wise lower bounds. If scalar, apply bound to all columns.
        upper: The column-wise upper bounds. If scalar, apply bound to all columns.
        raise_on_violation: If `True`, raise an exception when the elements in `X`
            are out of the specified bounds (up to numerical accuracy). This is
            useful for post-processing candidates generated by optimizers that
            satisfy imposed bounds only up to numerical accuracy.

    Returns:
        The clamped tensor.
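
    Example:
        A minimal usage sketch (the tensor values below are illustrative):

        >>> X = torch.rand(4, 3) * 2.0 - 0.5  # some entries fall outside [0, 1]
        >>> X_clamped = columnwise_clamp(X, lower=0.0, upper=1.0)
        >>> # Column-wise bounds can be passed as `d`-dim tensors.
        >>> X_clamped = columnwise_clamp(X, lower=torch.zeros(3), upper=torch.ones(3))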
    """
    if lower is None and upper is None:
        return X

    if lower is not None:
        lower = torch.as_tensor(lower).expand_as(X).to(X)

    if upper is not None:
        upper = torch.as_tensor(upper).expand_as(X).to(X)
        if lower is not None and (lower > upper).any():
            raise ValueError("Lower bounds cannot exceed upper bounds.")

    out = X.clamp(lower, upper)
    if raise_on_violation and not X.allclose(out):
        raise BotorchError(
            "Original value(s) are out of bounds: " f"{out=}, {X=}, {lower=}, {upper=}."
        )

    return out


def fix_features(
    X: Tensor,
    fixed_features: Mapping[int, float | Tensor] | None = None,
    replace_current_value: bool = True,
) -> Tensor:
    r"""Fix feature values in a Tensor.

    The fixed features will have zero gradient in downstream calculations.

    Args:
        X: Input tensor with shape `b x q x p` if `replace_current_value` is True,
            and `b x q x reduced_p` otherwise, where `p` is the full number of
            features and `reduced_p` is the number of features that are not fixed
            to a constant value.
        fixed_features: A mapping from column indices to the values the
            corresponding features should be set to in `X`. Keys should be in the
            range `[0, p - 1]`. If a tensor is passed as a value, it must have
            shape `b x q` or `b`; in the latter case, the same value is used
            across the `q` dimension.
        replace_current_value: If True, the values in the specified columns of `X`
            are replaced; otherwise, new columns holding the fixed values are
            inserted at the corresponding indices.

    Returns:
        The tensor X with fixed features.
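
    Example:
        A minimal usage sketch fixing feature 1 (the values are illustrative):

        >>> X = torch.rand(2, 3, 4)  # `b x q x p` with b=2, q=3, p=4
        >>> X_fixed = fix_features(X, fixed_features={1: 0.5})
        >>> assert torch.all(X_fixed[..., 1] == 0.5)
        >>> # A `b`-dim tensor value is broadcast across the `q` dimension.
        >>> X_fixed = fix_features(X, fixed_features={1: torch.tensor([0.25, 0.75])})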
    """
    if fixed_features is None:
        return X

    if replace_current_value:
        X = X[..., [i for i in range(X.shape[-1]) if i not in fixed_features]]

    # Allocate an output tensor with room for the fixed feature columns.
    new_X = torch.zeros(
        *X.shape[:-1],
        (X.shape[-1] + len(fixed_features)),
        dtype=X.dtype,
        device=X.device,
    )

    # Fill `new_X` column by column: fixed columns take their specified values,
    # and the remaining columns are copied from `X` in order.
    filtered_index = 0
    for index in range(new_X.shape[-1]):
        if index in fixed_features:
            value = fixed_features[index]
            if torch.is_tensor(value) and value.ndim > 0:
                if X.ndim != 3:
                    raise BotorchTensorDimensionError(
                        "X must be a 3-dimensional tensor, as value is a tensor."
                        f"X.shape = {X.shape}, value.shape = {value.shape}."
                    )
                _b, q, _reduced_p = X.shape
                if value.ndim == 1:
                    # Repeat values across the q dimension.
                    value = value.unsqueeze(-1).repeat(1, q)
            else:
                value = torch.full_like(new_X[..., index], value)
            new_X[..., index] = value
        else:
            new_X[..., index] = X[..., filtered_index]
            filtered_index += 1

    return new_X


def get_X_baseline(acq_function: AcquisitionFunction) -> Tensor | None:
    r"""Extract X_baseline from an acquisition function.

    This tries to find the baseline set of points. First, this checks if the
    acquisition function has an `X_baseline` attribute. If it does not,
    then this method attempts to use the model's `train_inputs` as `X_baseline`.

    Args:
        acq_function: The acquisition function.

    Returns:
        An optional `n x d`-dim tensor of baseline points. This is None if no
        baseline points are found.
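
    Example:
        A minimal usage sketch (the model and data below are illustrative):

        >>> from botorch.acquisition import qNoisyExpectedImprovement
        >>> from botorch.models import SingleTaskGP
        >>> train_X = torch.rand(20, 2, dtype=torch.float64)
        >>> train_Y = train_X.sum(dim=-1, keepdim=True)
        >>> model = SingleTaskGP(train_X, train_Y)
        >>> acqf = qNoisyExpectedImprovement(model, X_baseline=train_X)
        >>> X_baseline = get_X_baseline(acqf)  # returns the `20 x 2` baseline set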
    """
    try:
        X = acq_function.X_baseline
        # if there are no baseline points, use training points
        if X.shape[0] == 0:
            raise BotorchError
    except (BotorchError, AttributeError):
        try:
            # some acquisition functions do not have a model attribute
            # e.g. FixedFeatureAcquisitionFunction
            model = acq_function.model
        except AttributeError:
            warn("Failed to extract X_baseline.", BotorchWarning)
            return
        try:
            # Make sure we get the original train inputs.
            m = model.models[0] if isinstance(model, ModelListGPyTorchModel) else model
            if m._has_transformed_inputs:
                X = m._original_train_inputs
            else:
                X = m.train_inputs[0]
        except (BotorchError, AttributeError):
            warn("Failed to extract X_baseline.", BotorchWarning)
            return
    # just use one batch
    while X.ndim > 2:
        X = X[0]
    return X
